import os
import gc
import random
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.utils.data as Data
from torch.utils.data.dataset import TensorDataset
import torchdiffeq

def load_data(path, batch_size = 128):
    """Build the train and test data loaders from ``.npy`` files stored under ``path``.

    Each split is read from three arrays -- ``<split>_t.npy``, ``<split>_x.npy`` and
    ``<split>_y.npy`` -- which are converted to float tensors and zipped into a
    ``TensorDataset`` in that order.

    Arguments:
        path: Directory containing the six ``.npy`` files.
        batch_size: Batch size used by both loaders.

    Returns:
        A ``(train_data_loader, test_data_loader)`` tuple. Only the training loader shuffles.
    """
    split_files = {
        'train': ('train_t.npy', 'train_x.npy', 'train_y.npy'),
        'test': ('test_t.npy', 'test_x.npy', 'test_y.npy'),
    }

    def _read_split(split):
        # One float tensor per file, kept in (t, x, y) order.
        tensors = [
            torch.FloatTensor(np.load(os.path.join(path, file_name)))
            for file_name in split_files[split]
        ]
        return TensorDataset(*tensors)

    train_data = _read_split('train')
    test_data = _read_split('test')

    train_data_loader = Data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_data_loader = Data.DataLoader(test_data, batch_size=batch_size)
    return train_data_loader, test_data_loader

class VectorField(nn.Module):
    def __init__(self, dX_dt, func):
        """Defines a controlled vector field.

        Arguments:
            dX_dt: A callable mapping a time ``t`` to the derivative of the control path at that
                time, of shape (..., input_channels).
            func: An ``nn.Module`` mapping a hidden state ``z`` to a tensor of shape
                (..., hidden_channels, input_channels).

        Raises:
            ValueError: If ``func`` is not an ``nn.Module``.
        """
        super(VectorField, self).__init__()
        if not isinstance(func, nn.Module):
            raise ValueError("func must be a nn.Module.")

        self.dX_dt = dX_dt
        self.func = func

    def forward(self, t, z):
        """Evaluate ``func(z) @ dX_dt(t)``, i.e. the right-hand side ``dz/dt`` at time ``t``.

        Implemented as ``forward`` (rather than overriding ``__call__``) so that calling the
        module goes through ``nn.Module.__call__`` and registered hooks still run.
        """
        # control_gradient is of shape (..., input_channels)
        control_gradient = self.dX_dt(t)
        # vector_field is of shape (..., hidden_channels, input_channels)
        vector_field = self.func(z)
        # out is of shape (..., hidden_channels)
        # (The squeezing is necessary to make the matrix-multiply properly batch in all cases)
        out = (vector_field @ control_gradient.unsqueeze(-1)).squeeze(-1)
        return out

class NaturalCubicSpline:
    """Batched piecewise-cubic path defined by per-interval polynomial coefficients.

    On the interval starting at knot ``i`` the path is
    ``a + b * s + (two_c / 2) * s**2 + (three_d / 3) * s**3`` where ``s`` is the time elapsed
    since that knot.

    Arguments:
        times: Knot times, indexed as (batch_size, n_knots).
            NOTE(review): the interval lookup counts knots below ``t``, so each row is
            presumably sorted in increasing order -- confirm against the data.
        coeffs: Tuple ``(a, b, two_c, three_d)``; each tensor is indexed as
            (batch_size, interval, channel).
    """

    def __init__(self, times, coeffs, **kwargs):
        super(NaturalCubicSpline, self).__init__(**kwargs)
        (a, b, two_c, three_d) = coeffs
        self._times = times # times.shape == (batch_size, n_take)
        self._a = a
        self._b = b
        # as we're typically computing derivatives, we store the multiples of these coefficients that are more useful
        self._two_c = two_c
        self._three_d = three_d
        # Row selector used for advanced indexing; kept on the same device as the data it indexes
        # so every lookup does not mix devices.
        self.range = torch.arange(0, self._times.size(0), device=self._times.device)

    def _interpret_t(self, t):
        """Return ``(time since the start of the active interval, interval index)`` per batch row."""
        maxlen = self._b.size(-2) - 1
        index = (t > self._times).sum(dim=1) - 1 # index.size == (batch_size)
        index = index.clamp(0, maxlen)  # clamp because t may go outside of [t[0], t[-1]]; this is fine
        # will never access the last element of self._times; this is correct behaviour
        fractional_part = t - self._times[self.range, index]
        return fractional_part.unsqueeze(dim=1), index

    def evaluate(self, t):
        """Evaluates the natural cubic spline interpolation at a point t, which should be a scalar tensor."""
        fractional_part, index = self._interpret_t(t)
        # Horner evaluation of a + b*s + c*s^2 + d*s^3 using the stored 2c and 3d.
        inner = 0.5 * self._two_c[self.range, index, :] + self._three_d[self.range, index, :] * fractional_part / 3
        inner = self._b[self.range, index, :] + inner * fractional_part
        return self._a[self.range, index, :] + inner * fractional_part

    def derivative(self, t):
        """Evaluates the derivative of the natural cubic spline at a point t, which should be a scalar tensor."""
        fractional_part, index = self._interpret_t(t)
        # b + 2c*s + 3d*s^2, again in Horner form.
        inner = self._two_c[self.range, index, :] + self._three_d[self.range, index, :] * fractional_part
        deriv = self._b[self.range, index, :] + inner * fractional_part
        return deriv

class CDEFunc(nn.Module):
    """Two-layer network producing a (hidden_channels x input_channels) matrix per hidden state."""

    def __init__(self, input_channels, hidden_channels):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        # Hidden state -> hidden state -> flattened (hidden_channels * input_channels) matrix.
        self.linear1 = nn.Linear(hidden_channels, hidden_channels)
        self.linear2 = nn.Linear(hidden_channels, input_channels * hidden_channels)

    def forward(self, z):
        """Map ``z`` of shape (..., hidden_channels) to (..., hidden_channels, input_channels)."""
        hidden = torch.relu(self.linear1(z))
        # tanh bounds every entry of the resulting matrix to (-1, 1).
        flat = torch.tanh(self.linear2(hidden))

        leading_dims = flat.shape[:-1]
        return flat.view(*leading_dims, self.hidden_channels, self.input_channels)

class NeuralCDE(nn.Module):
    """Neural controlled differential equation: spline-interpolate the input, integrate a hidden
    state driven by that spline, and read the prediction off the final hidden state."""

    def __init__(self, input_channels, hidden_channels, output_channels):
        super().__init__()
        self.hidden_channels = hidden_channels

        self.func = CDEFunc(input_channels, hidden_channels)
        self.initial = nn.Linear(input_channels, hidden_channels)
        self.readout = nn.Linear(hidden_channels, output_channels)

    def forward(self, t, x):
        """Return the readout of the hidden state at the end of the integration interval.

        Arguments:
            t: Observation times, indexed as (batch, time).
            x: Observations passed to ``cubic_spline`` together with ``t``.
        """
        coeffs = cubic_spline(t, x)
        spline = NaturalCubicSpline(t, coeffs)

        # The initial hidden state is a linear map of the path value at the first time point.
        start_time = t[0, 0]
        z0 = self.initial(spline.evaluate(start_time))

        # NOTE(review): the integration interval is taken from the first sample's first and last
        # times only; this presumes every sample in the batch spans the same interval -- confirm.
        interval = t[0, [0, -1]]
        dynamics = VectorField(dX_dt=spline.derivative, func=self.func)
        trajectory = torchdiffeq.odeint_adjoint(func=dynamics, y0=z0, t=interval, atol=1e-2, rtol=1e-2)

        # trajectory[1] is the state at the second requested time, i.e. the end of the interval.
        final_state = trajectory[1]
        return self.readout(final_state)

def tridiagonal_solve(b_, A_upper_, A_diagonal_, A_lower_):
    """Solves a batch of tridiagonal systems Ax = b with the Thomas algorithm.

    Letting U = A_upper, D = A_diagonal and L = A_lower, the matrix A of each batch element is of
    size (k, k), with entries:

    D[0] U[0]
    L[0] D[1] U[1]
         L[1] D[2] U[2]                     0
              L[2] D[3] U[3]
                  .    .    .
                       .      .      .
                           .        .        .
           0            L[k - 3] D[k - 2] U[k - 2]
                                 L[k - 2] D[k - 1]

    Arguments:
        b_: Right-hand sides, of shape (batch, channels, k). Every channel of a batch element is
            solved against that element's matrix.
        A_upper_: A tensor of shape (batch, k - 1).
        A_diagonal_: A tensor of shape (batch, k).
        A_lower_: A tensor of shape (batch, k - 1).
        The three diagonals may also already carry the channel axis, i.e. be of shape
        (batch, channels, k - 1) / (batch, channels, k).

    Returns:
        A tensor of shape (batch, channels, k), corresponding to the x solving Ax = b.

    Warning:
        The elimination is a sequential Python loop over k. You probably want to cache the
        result, if possible.
    """

    def _align(coeff, width):
        # Match the dtype/device of the right-hand side, then insert singleton axes directly
        # after the batch axis so the per-sample diagonal broadcasts over the channel axis.
        coeff = coeff.to(dtype=b_.dtype, device=b_.device)
        while coeff.dim() < b_.dim():
            coeff = coeff.unsqueeze(1)
        # expand() is a view (no copy) and raises early on incompatible shapes.
        return coeff.expand(*b_.shape[:-1], width)

    system_size = b_.size(-1)
    A_upper = _align(A_upper_, system_size - 1)
    A_lower = _align(A_lower_, system_size - 1)
    A_diagonal = _align(A_diagonal_, system_size)

    # Forward sweep: eliminate the lower diagonal, updating the diagonal and right-hand side.
    new_A_diagonal = [A_diagonal[..., 0]]
    new_b = [b_[..., 0]]
    for i in range(1, system_size):
        w = A_lower[..., i - 1] / new_A_diagonal[i - 1]
        new_A_diagonal.append(A_diagonal[..., i] - w * A_upper[..., i - 1])
        new_b.append(b_[..., i] - w * new_b[i - 1])

    # Back substitution, from the last unknown to the first.
    outs = [None] * system_size
    outs[system_size - 1] = new_b[system_size - 1] / new_A_diagonal[system_size - 1]
    for i in range(system_size - 2, -1, -1):
        outs[i] = (new_b[i] - A_upper[..., i] * outs[i + 1]) / new_A_diagonal[i]

    return torch.stack(outs, dim=-1)

def cubic_spline(times, x):
    """Computes the coefficients of the natural cubic spline through a batch of paths.

    Arguments:
        times: Knot times of shape (batch_size, length).
        x: Observations of shape (batch_size, length, channels).

    Returns:
        A tuple ``(a, b, two_c, three_d)`` of tensors, each of shape
        (batch_size, length - 1, channels), giving for every interval the value, the first
        derivative, twice the quadratic coefficient and three times the cubic coefficient at the
        interval's left knot.

    Raises:
        ValueError: If the time dimension has fewer than two points.
    """
    path = x.transpose(-1, -2)  # (batch_size, channels, length)
    length = path.size(-1)

    if length < 2:
        # With no interval there is nothing to interpolate; the system below would be 0 / 0.
        raise ValueError("Must have a time dimension of size at least 2.")

    # Set up some intermediate values
    time_diffs = times[:, 1:] - times[:, :-1]
    time_diffs_reciprocal = time_diffs.reciprocal()
    time_diffs_reciprocal_squared = time_diffs_reciprocal ** 2

    three_path_diffs = 3 * (path[..., 1:] - path[..., :-1])
    six_path_diffs = 2 * three_path_diffs

    # path_diffs_scaled.shape == (batch_size, input_size, n_take)
    path_diffs_scaled = three_path_diffs * time_diffs_reciprocal_squared.unsqueeze(dim=1)

    # Solve a tridiagonal linear system to find the derivatives at the knots.
    # Each diagonal entry is 2 * (1/h_left + 1/h_right), with the missing neighbour at either end
    # contributing nothing.
    system_diagonal = torch.empty(times.size(0), length, dtype=path.dtype, device=path.device)
    system_diagonal[:, :-1] = time_diffs_reciprocal
    system_diagonal[:, -1] = 0
    system_diagonal[:, 1:] += time_diffs_reciprocal
    system_diagonal *= 2
    system_rhs = torch.empty(*path.shape, dtype=path.dtype, device=path.device)
    system_rhs[..., :-1] = path_diffs_scaled
    system_rhs[..., -1] = 0
    system_rhs[..., 1:] += path_diffs_scaled

    knot_derivatives = tridiagonal_solve(system_rhs, time_diffs_reciprocal, system_diagonal, time_diffs_reciprocal)

    # For length == 2 the solve yields the chord slope at both knots, so two_c and three_d come
    # out as zero and the spline is the straight line between the two points.
    a = path[..., :-1]
    b = knot_derivatives[..., :-1]
    two_c = (six_path_diffs * time_diffs_reciprocal.unsqueeze(dim=1)
            - 4 * knot_derivatives[..., :-1]
            - 2 * knot_derivatives[..., 1:]) * time_diffs_reciprocal.unsqueeze(dim=1)
    three_d = (-six_path_diffs * time_diffs_reciprocal.unsqueeze(dim=1)
            + 3 * (knot_derivatives[..., :-1]
                    + knot_derivatives[..., 1:])) * time_diffs_reciprocal_squared.unsqueeze(dim=1)

    return a.transpose(-1, -2), b.transpose(-1, -2), two_c.transpose(-1, -2), three_d.transpose(-1, -2)

def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # This has to be *called*; assigning to the name would only overwrite the function.
    # warn_only keeps operations without a deterministic implementation usable (they warn
    # instead of raising).
    torch.use_deterministic_algorithms(True, warn_only=True)


def _run_epoch(model, loader, criterion, device, optimizer=None, scheduler=None):
    """Run one pass over ``loader``; trains when ``optimizer`` is given, otherwise evaluates.

    Returns:
        ``(mean loss per target element, fraction of correctly predicted target elements)``.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    corrects = 0
    num_elements = 0
    with torch.set_grad_enabled(training):
        for t, x, y in tqdm(loader):
            t, x, y = t.to(device), x.to(device), y.to(device)
            output = model(t, x)
            loss = criterion(output, y)

            if training:
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                if scheduler is not None:
                    # The schedule is stepped once per batch, not once per epoch.
                    scheduler.step()

            batch_elements = x.size(0) * y.size(1)
            # A logit above zero is a positive prediction.
            corrects += int(torch.sum((output > 0).int() == y))
            # The criterion averages over every target element, so weight the batch loss by the
            # same element count that the running total is later divided by.
            loss_sum += loss.item() * batch_elements
            num_elements += batch_elements

    return loss_sum / num_elements, corrects / num_elements


def _save_history(frame, version, split):
    """Write ``frame`` to ``{version}_{split}.csv``, falling back to a ``_1`` suffixed file."""
    try:
        frame.to_csv(f'{version}_{split}.csv')
    except Exception:
        # Best effort: e.g. the primary file may be locked by another program.
        print(f'Fail to save the file {split}.csv')
        frame.to_csv(f'{version}_{split}_1.csv')


def main(version, hidden_size):
    """Train and evaluate a NeuralCDE on the 'irregular_spring' data.

    After every epoch the model is saved to ``{version}.pkl`` and the loss/accuracy histories are
    written to ``{version}_Train.csv`` and ``{version}_Test.csv``.

    Arguments:
        version: Prefix used for log lines and for every output file.
        hidden_size: Number of hidden channels of the NeuralCDE.
    """
    _seed_everything(42)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    num_epochs = 50
    batch_size = 128

    train_loader, test_loader = load_data('irregular_spring', batch_size)

    print(device)

    model = NeuralCDE(20, hidden_size, 20).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # T_max counts optimizer steps, so the cosine schedule spans the whole run.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = num_epochs * len(train_loader), eta_min = 0.00005, last_epoch = -1)

    criterion = nn.BCEWithLogitsLoss()

    train_loss = []
    train_acc = []
    test_loss = []
    test_acc = []

    for epoch in range(num_epochs):
        print(version, 'Epoch {}/{}'.format(epoch + 1, num_epochs))

        epoch_loss, epoch_acc = _run_epoch(model, train_loader, criterion, device, optimizer, scheduler)
        train_loss.append(epoch_loss)
        train_acc.append(epoch_acc)
        print(' ', train_loss[-1], train_acc[-1])

        epoch_loss, epoch_acc = _run_epoch(model, test_loader, criterion, device)
        test_loss.append(epoch_loss)
        test_acc.append(epoch_acc)
        print(' ', test_loss[-1], test_acc[-1])

        # NOTE: this pickles the whole module object, so loading it back requires the model
        # classes to be importable.
        torch.save(model, f'{version}.pkl')

        _save_history(pd.DataFrame({'Train Loss': train_loss, 'Train Acc': train_acc}), version, 'Train')
        _save_history(pd.DataFrame({'Test Loss': test_loss, 'Test Acc': test_acc}), version, 'Test')

    gc.collect()

if __name__ == '__main__':
    # Train one model per hidden width; the width is embedded in the output file prefix.
    for hidden_size in (128, 256, 512):
        main(f'CDE_IrrSpring_{hidden_size}', hidden_size)